'''
Author: SlytherinGe
LastEditTime: 2021-05-11 17:13:49
'''
if __name__ == '__main__':
    import voc_label_utility as VL
else:
    import dataset_utilty.voc_label_utility as VL

import os

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

def get_gt_width_height_in_one_img(label_root):
    """Collect the width and height of every annotated box in one label file.

    The label is loaded with ``VL.voc_label_preprocess``; the second item of
    its return value is iterated, and each element is resolved into a dict
    with ``VL.voc_object_xml_element_resolver``.

    Args:
        label_root: Path of the annotation file, passed straight to
            ``VL.voc_label_preprocess``.

    Returns:
        A list with one ``{'width': xmax - xmin, 'height': ymax - ymin}``
        dict per object, in the order the objects are yielded.
    """
    _, object_elements = VL.voc_label_preprocess(label_root)

    box_sizes = []
    for element in object_elements:
        resolved = VL.voc_object_xml_element_resolver(element)
        # Coordinates are converted with int(), so they must be integer-like.
        x_min = int(resolved['bndbox']['xmin'])
        y_min = int(resolved['bndbox']['ymin'])
        x_max = int(resolved['bndbox']['xmax'])
        y_max = int(resolved['bndbox']['ymax'])
        box_sizes.append({'width': x_max - x_min, 'height': y_max - y_min})

    return box_sizes

def get_biggest_gt_size(dataset_root):
    """Find the widest and the tallest ground-truth box in a dataset.

    Every entry of ``dataset_root`` is handed to
    ``get_gt_width_height_in_one_img`` as an annotation file path.

    Args:
        dataset_root: Directory whose entries are annotation files.

    Returns:
        A ``(most_wide, most_high)`` tuple of ``{'width': ..., 'height': ...}``
        dicts: the box with the largest width and the box with the largest
        height. On ties the first box encountered wins.

    Raises:
        IndexError: If no ground-truth box is found under ``dataset_root``.
    """
    list_gt_w_h = []
    # Sorted so that tie-breaking does not depend on the arbitrary order in
    # which os.listdir returns directory entries.
    for file_name in sorted(os.listdir(dataset_root)):
        # Join against the argument rather than a script-only global, so the
        # function also works when this module is imported.
        ann_file_path = os.path.join(dataset_root, file_name)
        list_gt_w_h.extend(get_gt_width_height_in_one_img(ann_file_path))

    if not list_gt_w_h:
        raise IndexError(
            'no ground-truth boxes found under {!r}'.format(dataset_root)
        )

    # max() returns the first maximal item, matching a strict '>' scan.
    most_wide = max(list_gt_w_h, key=lambda w_h: w_h['width'])
    most_high = max(list_gt_w_h, key=lambda w_h: w_h['height'])
    return (most_wide, most_high)

def plot_dataset_gt_shape_seaborn(dataset_root):
    """Show a joint width/height plot of all ground-truth boxes in a dataset.

    Every entry of ``dataset_root`` is handed to
    ``get_gt_width_height_in_one_img`` as an annotation file path. The
    collected sizes are drawn as a scatter plot (width on x, height on y)
    with marginal distributions, and the figure is displayed with
    ``plt.show()``.

    Args:
        dataset_root: Directory whose entries are annotation files.

    Returns:
        None.

    Raises:
        IndexError: If no ground-truth box is found under ``dataset_root``.
    """
    list_gt_w_h = []
    for file_name in os.listdir(dataset_root):
        # Join against the argument rather than a script-only global, so the
        # function also works when this module is imported.
        ann_file_path = os.path.join(dataset_root, file_name)
        list_gt_w_h.extend(
            [w_h['width'], w_h['height']]
            for w_h in get_gt_width_height_in_one_img(ann_file_path)
        )

    # Fail early with a clear message, before the global seaborn theme is
    # touched; an empty array cannot be column-indexed below.
    if not list_gt_w_h:
        raise IndexError(
            'no ground-truth boxes found under {!r}'.format(dataset_root)
        )

    gt_data = np.array(list_gt_w_h)  # one row per box: [width, height]
    sns.set_theme()
    sns.set_style('white')
    g = sns.JointGrid(x=gt_data[:, 0], y=gt_data[:, 1])
    g = g.plot_joint(plt.scatter, color='green', s=25, linewidth=0, alpha=0.2)
    # NOTE(review): seaborn documents distplot as deprecated; it is kept
    # because hist_kws has no drop-in equivalent in histplot. Confirm against
    # the seaborn version in use before migrating.
    g.plot_marginals(sns.distplot, kde=True, hist_kws={'color': 'g', 'edgecolor': 'k'})
    g.set_axis_labels(xlabel='width', ylabel='height')
    plt.show()

# Script entry point: plot the ground-truth box width/height distribution of
# the annotation directory configured below.
if __name__ == '__main__':
    import os
    # Absolute, machine-specific annotation directory. ROOT becomes a
    # module-level name when this file is run as a script.
    ROOT = '/media/gejunyao/Disk1/Datasets/SSDD/VOC2012/Annotations/'
    plot_dataset_gt_shape_seaborn(ROOT)


# if __name__ == '__main__':
#     import os
#     import numpy as np
#     import matplotlib.pyplot as plt
    # NOTE: the bare string below is still indented under the `if` above, so
    # it is a no-op expression statement of that block, used as a section
    # heading. It reads "AIRSarShip ship size statistics".
    '''
    AIRSarShip船只大小统计
    '''
    # list_gt_w_h = []
    # ROOT = "/media/gejunyao/Disk1/Datasets/CAS-SARship/SARShip/"
    # listdir = os.listdir(ROOT)
    # for father in listdir:
    #     sons = os.listdir(os.path.join(ROOT, father))
    #     for son in sons:
    #         if son.endswith('.xml'):
    #             ann_file_path = os.path.join(ROOT, father, son)
    #             gt_w_h = get_gt_width_height_in_one_img(ann_file_path)
    #             list_gt_w_h.extend(gt_w_h)
    
    # list_size = []
    # for w_h in list_gt_w_h:
    #     list_size.append(w_h['width'] * w_h['height'])
    # plt.figure()
    # plt.hist(np.asarray(list_size), bins=100, edgecolor="black", density=True)
    # plt.title('ship size distribution in AIRSarShip')
    # plt.xticks(np.linspace(0,80000,21))  # 21 evenly spaced x ticks (step 4000)
    # plt.xlabel('area(pixels)')
    # plt.ylabel('percentage')
    # plt.show()
    # '''
    # CAS-OpenSARShip ship size statistics (note: the path and plot title below refer to SSDD)
    # '''
    # list_gt_w_h = []
    # ROOT = "/media/gejunyao/Disk1/Datasets/SSDD/Annotations/"
    # listdir = os.listdir(ROOT)
    # for father in listdir:
    #     ann_file_path = os.path.join(ROOT, father)
    #     gt_w_h = get_gt_width_height_in_one_img(ann_file_path)
    #     list_gt_w_h.extend(gt_w_h)
    
    # list_size = []
    # for w_h in list_gt_w_h:
    #     list_size.append(w_h['width'] * w_h['height'])
    # print(max(list_size))
    # plt.figure()
    # plt.hist(np.asarray(list_size), bins=100, edgecolor="black", density=True)
    # plt.title('ship size distribution in SSDD')
    # # plt.xticks(np.linspace(0,16000,21))  # 21 evenly spaced x ticks (step 800)
    # plt.xlabel('area(pixels)')
    # plt.ylabel('percentage')
    # plt.show()
